# -*- coding: utf-8 -*-
# This script was originally run on a supercomputing node via the attached sbatch file.
# It can be run in either of two ways:
# 1. On a supercomputing node: first create the conda environment with
#    conda env create -f environment.yml
# 2. On Google Colab: run it after running train_bert_ac.py.


import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import pandas as pd
import numpy as np


from tabulate import tabulate
from tqdm import trange
import random


# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

## Prepare training data
# The coded training data is read here only to rebuild the label set, so the
# classifier head created later gets the same number of outputs.
df = pd.read_csv("./ac_coded1.csv")
df2 = pd.read_csv("./ac_coded2.csv")

## Merge data
df = df[['text', 'category']]
# The second file stores the text under 'content'; align it with the first file.
df2 = df2[['content', 'category']].rename(columns={'content': "text"}, errors="raise")

result = pd.concat([df, df2])

# Drop duplicate texts (the first occurrence is kept).
result = result.drop_duplicates("text").reset_index(drop=True)
# Show the class distribution; a bare expression would only display in a notebook.
print(result['category'].value_counts())

text = result.text.values
labels = result.category.values

## Preprocess data for BERT prediction

## Get label dictionary for future purposes
# Maps each category to an index, in order of first appearance in `result`.
# NOTE(review): predicted class indices are only meaningful if the training
# script built this mapping in the same order; confirm against train_bert_ac.py.
possible_labels = result['category'].unique()
label_dict = {label: index for index, label in enumerate(possible_labels)}

## Prepare prediction data
# To predict a different file, replace the file name below, e.g.:
# df4 = pd.read_csv("./sinanews_to_be_predicted.csv")
df4 = pd.read_csv("./sinaclick_news.csv")
df4 = df4[['autoid', 'title']]

# Report the number of sentences.
print('Number of test sentences: {:,}\n'.format(df4.shape[0]))

# Titles to classify, and their ids (kept in the same row order).
sentences = df4.title.values
timeuids = df4.autoid.values

# Tokenize all titles in one batched call. Each title gets [CLS]/[SEP] added,
# is padded or truncated to exactly 32 tokens, and receives an attention mask
# that marks the real (non-[PAD]) tokens.
# `padding='max_length'` replaces the deprecated `pad_to_max_length=True`.
encoded = tokenizer(
    list(sentences),
    add_special_tokens=True,
    max_length=32,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt',
)
input_ids = encoded['input_ids']
attention_masks = encoded['attention_mask']

# Set the batch size.
batch_size = 32

# Sequential sampling keeps the batches in df4's row order. Predictions are
# later paired with `timeuids` purely by position, so the order must not change.
prediction_data = TensorDataset(input_ids, attention_masks)
prediction_sampler = SequentialSampler(prediction_data)
prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)

# Prediction on test set
def evaluate_pred2(dataloader_val, classifier=None, target_device=None):
    """Run the classifier over ``dataloader_val`` and collect raw logits.

    Args:
        dataloader_val: DataLoader yielding ``(input_ids, attention_mask)``
            batches with no labels.
        classifier: Model to evaluate. Defaults to the module-level ``model``.
        target_device: Device each batch is moved to. Defaults to the
            module-level ``device``.

    Returns:
        A list with one numpy array per batch, holding that batch's
        ``logits`` on the CPU. Concatenate along axis 0 to get one row of
        scores per input sentence, in dataloader order.
    """
    # Fall back to the script-level globals so existing one-argument calls work.
    if classifier is None:
        classifier = model
    if target_device is None:
        target_device = device

    # Count the dataset actually being iterated, not a global tensor.
    print('Predicting labels for {:,} test sentences...'.format(len(dataloader_val.dataset)))

    # Evaluation mode disables dropout so predictions are deterministic.
    classifier.eval()

    predictions = []
    for batch in dataloader_val:
        b_input_ids, b_input_mask = (t.to(target_device) for t in batch)

        # No gradients are needed for inference; this saves memory and time.
        with torch.no_grad():
            output = classifier(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask,
                                return_dict=True)

        predictions.append(output.logits.detach().cpu().numpy())

    print('    DONE.')
    return predictions


## Run the fine-tuned model on the prediction data

torch.cuda.empty_cache()

# Build the architecture with one output per known category; the fine-tuned
# weights are loaded into it below.
model = BertForSequenceClassification.from_pretrained("bert-base-chinese",
                                                      num_labels=len(label_dict),
                                                      output_attentions=False,
                                                      output_hidden_states=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Path of the saved fine-tuned weights. Map them straight onto the device the
# model lives on.
model.load_state_dict(torch.load('./saved_model/finetuned_BERT_epoch_10.model', map_location=device))

predictions2 = evaluate_pred2(prediction_dataloader)

# Combine the per-batch logits into one (num_sentences, num_labels) array.
flat_predictions = np.concatenate(predictions2, axis=0)

# For each sample, pick the class index (0 .. len(label_dict) - 1) with the
# highest score.
flat_predictions = np.argmax(flat_predictions, axis=1)

# Pair each prediction with its post id by position. Building from columns
# keeps each column's own dtype, unlike stacking both into one array.
results_df = pd.DataFrame({"autoid": timeuids, "p": flat_predictions})

## Write results to a file (note: tab-separated despite the .csv extension)
results_df.to_csv("./sinanewsprediction_results.csv", sep='\t')